Risk-aware survival time prediction from whole slide pathological images

Deep-learning-based survival prediction can assist doctors by providing additional information for diagnosis by estimating the risk or time of death. The former focuses on ranking deaths among patients based on the Cox model, whereas the latter directly predicts the survival time of each patient. However, it is observed that survival time prediction for the patients, particularly with close observation times, possibly has incorrect orders, leading to low prediction accuracy. Therefore, in this paper, we present a whole slide image (WSI)-based survival time prediction method that takes advantage of both the risk as well as time prediction. Specifically, we propose to combine these two approaches by extracting the risk prediction features and using them as guides for the survival time prediction. Considering the high resolution of WSIs, we extract tumor patches from WSIs using a pre-trained tumor classifier and apply the graph convolutional network to aggregate information across these patches effectively. Extensive experiments demonstrate that the proposed method significantly improves the time prediction accuracy when compared with direct prediction of the survival times without guidance and outperforms existing methods.


Scientific Reports | (2022) 12:21948 | https://doi.org/10.1038/s41598-022-26096-z

... basically a regression problem 25. Therefore, it is important to predict survival times accurately and in the correct sequence. To this end, we propose to take advantage of both risk prediction and time prediction by allowing the estimated survival risk features to guide death ranking, thereby enabling the prediction of survival times with higher accuracy in the correct sequence. To better aggregate patch-level features to make WSI-level predictions, we construct a graph by connecting patches in each WSI. To eliminate as much interference (e.g., normal tissue regions that do not contribute to the diagnosis) as possible without extra manual annotation, we select patches detected as cancer by a pretrained patch-level binary classifier instead of using all tissue patches within WSIs or randomly sampling patches. Subsequently, a graph convolutional network (GCN) is used to aggregate information across patches effectively.

Our main contributions are threefold. First, we propose a two-branch GCN-based model that integrates risk and time prediction features to make more accurate predictions. To the best of our knowledge, this is the first work to use survival risk prediction as guidance to predict exact survival times of patients from WSIs. Second, we design our framework to predict survival times directly by effectively aggregating patch-level information through GCN. A distinct difference from existing survival time prediction works is that the prediction is made at the WSI level and manual annotation of ROIs is not required. Third, through extensive experiments, we show that our model achieved superior performance and generalization ability on several publicly available datasets in terms of the mean absolute error for survival time prediction.

Related work
Automatic survival prediction from pathological images is of significant help to doctors in diagnosis. Based on recent advances in deep learning, learning-based methods are receiving significant attention. Here, we review literature in this field from three perspectives: (1) Prediction level (i.e. patch-level vs. WSI-level), (2) network (i.e. CNN vs. GCN), and (3) output (i.e. risk vs. survival time).
First, due to the extremely large image size of WSIs, most neural networks cannot handle WSIs directly. Therefore, one research direction is to explore the potential of neural networks for survival prediction using small-size patches 14,15,26. Because multiple patches can be extracted from a single WSI, a training dataset consisting of a sufficient number of patches can be easily obtained, facilitating the research on survival prediction. To make survival prediction from a single patch feasible, Zhu et al. 14 manually acquired an ROI of size 1024 × 1024 for each WSI and performed a single-patch-based prediction using a deep convolutional neural network (CNN). Although promising results have been reported, such approaches require a significant amount of time and effort from experienced pathologists. Therefore, an alternative approach is to obtain and aggregate patch-level predictions from multiple patches in a WSI. For example, Zhu et al. 16 clustered multiple patches according to their phenotypes and selected distinctive ones for further aggregation to obtain a WSI-level prediction. There are also approaches that aggregate the patch-level features so that the prediction is made directly at the WSI level or patient level 14,17,19,27-31. Yao et al. 17 improved patch-level prediction and aggregation by adopting the multiple instance learning framework. Chen et al. 28 formulated WSIs as a graph-based data structure in the Euclidean space similar to a point cloud such that the node features can be hierarchically aggregated from local to global structures.
Second, network architecture plays an essential role in survival prediction. Motivated by the success of CNNs for general computer vision tasks, early studies on survival prediction are only based on CNNs. Many widely-used CNN architectures, such as VGG 32 , ResNet 33 , and U-Net 34 have been adopted as baseline network architectures for feature extraction or aggregation of survival prediction 15,29,35 . Chang et al. 29 proposed to extract patch features using a pre-trained ResNet-50 33 , then construct the feature map from the patch features according to the location of the image patches in the WSI, and apply convolution operation to the generated feature map to predict the survival risk. However, for aggregating patch-level features or predictions to make WSI-level predictions, CNNs are only applicable to regularly structured patches such as tiles 19 . In contrast, GCN 36 can operate on any graph structure defined from the patches in a WSI and effectively aggregate information across the vertices via graph convolution; thus, it is an alternative design of the CNNs. Li et al. 19 selected random patches from a WSI and connected them according to Euclidean distances. A graph attention mechanism was applied to aggregate information from important patches among randomly chosen patches. Wang et al. 27 sampled cancerous patches from a WSI using a patch-level CNN and applied nuclei instance segmentation to each sampled patch to help the survival prediction model learn the hierarchical graph representations.
Third, the type of survival prediction output also needs consideration in clinical practice. Cox-model-based methods have been extensively studied and have achieved outstanding performance in predicting the survival risk of patients 14,16-19,27-29. However, predicted risk values can only be used to compare the death rankings of patients, which is insufficient for estimating the survival time of each patient directly. Although further estimation of the baseline hazard function can enable survival time prediction, the baseline hazard function is an unspecified function 22, leading to a variety of ways to predict it 6,23,24. Meanwhile, Xiao et al. 15 formulated the patch-level direct survival time prediction as an ordinal regression problem and introduced a censoring-aware loss to deal with the censored data. However, survival time prediction is still conducted at the patch level, which requires the manual selection of patches.
In summary, WSI-level prediction is more desirable than patch-level prediction for practical diagnosis because of its ability to capture both local and global information without manual ROI annotation. Additionally, to aggregate information across different numbers and positions of patches effectively, GCN is a preferable choice for backbone architecture design. Finally, compared to risk prediction, direct time prediction is more straightforward and helpful for diagnosis. Based on these reasons, we propose to directly predict survival time from WSIs using the GCN backbone.

Results
Datasets. We used the public cancer datasets provided by The Cancer Genome Atlas (TCGA) as our experimental datasets. Our experiments were conducted on three bladder, breast, and brain cancer subtypes in TCGA: bladder urothelial carcinoma (BLCA), breast invasive carcinoma (BRCA), and glioblastoma multiforme (GBM). Each dataset contained high-resolution WSIs of patients with a wide range of observation times.
We first extracted patches of size 512 × 512 for each WSI (for the 20× magnification scale) and then applied our pre-trained tumor classifier to select tumor patches. We then excluded WSIs that did not contain any patches classified as tumors. Additionally, WSIs with exceptionally large observation times (more than 2000 days for BLCA and GBM or 4000 days for BRCA) were removed, resulting in the three datasets summarized in Table 1. Each dataset was divided into training (70%) and test (30%) sets, and we further split 30% of samples in the training set as the validation set. These datasets were used for both the RP-GCN and TP-GCN.

Implementation details. For graph construction, the node features extracted by ResNet-18 were of size 512 × 1. The number of neighbors k was set to 8, and the distance threshold d was set to 2560 pixels at the 20× magnification scale, which is the total length of five patches. The pooling ratios r for the three GCN blocks were set to 0.6, 0.6, and 0.5, respectively.
We trained the RP-GCN for 300 epochs using the Adam optimizer 37 and set the initial learning rate to 1e−4, which was halved after 10, 30, and 50 epochs. For the training of the TP-GCN, we quantized the observation time into several intervals such that each interval contained the same number of training samples 15. This partitioning led to balanced training samples for the N classifiers, where the number of training samples in each interval was seven in our experiments. The initial learning rate of TP-GCN was set to 5e−4, while the other settings were the same as those for RP-GCN. We adopted ReLU and tanh as the activation functions σ for RP-GCN and TP-GCN, respectively.
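The equal-frequency quantization of observation times can be illustrated with a short sketch. The function name, the choice of the minimum and maximum training times as the outer boundaries, and the placement of inner boundaries midway between neighbouring samples are all assumptions made for illustration; the text only states that each interval holds the same number of training samples.

```python
def equal_frequency_boundaries(train_times, per_interval):
    """Return boundaries [T_0, ..., T_N] so that each interval holds
    `per_interval` training times (the last one may hold fewer).

    Assumption: T_0 and T_N are the smallest and largest training times,
    and inner boundaries sit midway between neighbouring samples.
    """
    ordered = sorted(train_times)
    boundaries = [ordered[0]]
    for k in range(per_interval, len(ordered), per_interval):
        boundaries.append((ordered[k - 1] + ordered[k]) / 2)
    boundaries.append(ordered[-1])
    return boundaries


if __name__ == "__main__":
    # Six toy observation times (days), two samples per interval.
    print(equal_frequency_boundaries([60, 10, 40, 20, 50, 30], per_interval=2))
```

With seven samples per interval, as in the experiments, the number of intervals N follows from the size of the training set.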
Performance evaluation. Because the proposed model consists of RP-GCN and TP-GCN, we first evaluated the performance of RP-GCN to confirm that the predicted survival risk could be beneficial for the subsequent survival time prediction task. We then evaluated the performance of TP-GCN to verify the effectiveness of the proposed model. Several ablation studies were also conducted to demonstrate the necessity of each component of the proposed model. All experiments on the proposed and compared methods were conducted using the same datasets listed in Table 1.

Survival risk prediction.
We used the C-index 38 as our evaluation metric. The C-index evaluates how well the death ranking among patients is organized, and is defined as follows:

$$C_r = \frac{1}{K}\sum_{i=1}^{P}\sum_{j=1}^{P} \delta_i\, I_{\tau_i < \tau_j}\, I_{R_i > R_j},$$

where $C_r$ represents the C-index of risk prediction, K indicates the total number of comparable pairs (pairs in which patient i is uncensored and $\tau_i < \tau_j$), and P is the total number of patients. $R_i$, $\tau_i$, and $\delta_i$ denote the predicted risk, observation time, and censoring status of patient i. $I_a$ is the indicator function, i.e., $I_a = 1$ if a is true, and $I_a = 0$ otherwise.

During the training stage, each WSI was considered an individual data sample. During the validation and testing stages, if there were several WSIs for one patient, we considered the average value of the WSI-level predictions as the patient-level prediction result. The test results are listed in the first row of Table 3. The C-indexes of risk prediction for BLCA, BRCA, and GBM were obtained as 0.834, 0.627, and 0.563, respectively.
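As a minimal sketch of this evaluation, the snippet below averages WSI-level predictions per patient and then counts correctly ordered comparable pairs. The function names are hypothetical, and counting tied risks as discordant is an assumption of this sketch rather than something the text specifies.

```python
from collections import defaultdict


def patient_level(wsi_preds):
    """Average WSI-level predictions per patient.

    wsi_preds: iterable of (patient_id, value) pairs.
    """
    grouped = defaultdict(list)
    for pid, value in wsi_preds:
        grouped[pid].append(value)
    return {pid: sum(vals) / len(vals) for pid, vals in grouped.items()}


def c_index_risk(times, events, risks):
    """Fraction of comparable pairs whose predicted risks are correctly ordered.

    A pair (i, j) is comparable when patient i is uncensored (events[i] == 1)
    and times[i] < times[j]; it is concordant when risks[i] > risks[j].
    Assumption: tied risks count as discordant.
    """
    comparable = concordant = 0
    for i in range(len(times)):
        if events[i] != 1:
            continue
        for j in range(len(times)):
            if times[i] < times[j]:
                comparable += 1
                concordant += risks[i] > risks[j]
    return concordant / comparable if comparable else float("nan")


if __name__ == "__main__":
    print(patient_level([("a", 1.0), ("a", 3.0), ("b", 2.0)]))
    print(c_index_risk([100, 200, 300], [1, 1, 0], [0.5, 0.9, 0.1]))
```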

Survival time prediction.
For the performance evaluation of TP-GCN, we used the mean absolute error (MAE) between the ground truth and predicted survival time, following 15. Specifically, for each patient with a censoring status $\delta$, the predicted survival time $\hat{\tau}$, and observation time $\tau$, the MAE can be calculated as follows:

$$\mathrm{MAE} = \frac{1}{P}\sum_{i=1}^{P}\Big[\delta_i\,|\hat{\tau}_i - \tau_i| + (1-\delta_i)\max(0,\ \tau_i - \hat{\tau}_i)\Big].$$

We also evaluated the pairwise ranking consistency between the predicted survival times by calculating the C-index as follows:

$$C_t = \frac{1}{K}\sum_{i=1}^{P}\sum_{j=1}^{P} \delta_i\, I_{\tau_i < \tau_j}\, I_{\hat{\tau}_i < \hat{\tau}_j}.$$

We compared the proposed model with DeepAttnMISL 17, PatchGCN 28, DeepGraphSurv 19, and CDOR 15. We conducted the performance comparisons using the same experimental settings introduced in the corresponding studies and the same datasets used in our experiments. For all compared methods, we selected the model with the best performance on the validation set for evaluating the performance on the test set.
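The censoring-aware MAE can be sketched in a few lines. The function name is hypothetical, and the treatment of censored patients is an assumption consistent with the ablation discussion, which notes that any prediction above the observation time is correct for censored data: such a patient contributes an error only when the prediction falls below the last observation time.

```python
def censoring_aware_mae(pred_times, obs_times, events):
    """Mean absolute error in days with one-sided error for censored patients.

    events[i] == 1: death observed, error is |pred - obs|.
    events[i] == 0: censored, error is max(0, obs - pred) (assumption).
    """
    total = 0.0
    for pred, obs, event in zip(pred_times, obs_times, events):
        if event == 1:
            total += abs(pred - obs)
        else:
            total += max(0.0, obs - pred)
    return total / len(pred_times)


if __name__ == "__main__":
    # One uncensored patient and two censored patients (toy values).
    print(censoring_aware_mae([100, 500, 200], [150, 400, 300], [1, 0, 0]))
```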
Because the other methods, except for CDOR 15, can only predict survival risk, we followed the approach in 15, which converts the survival risks predicted by different methods into time results using the Python package lifelines 39. This package can automatically estimate the baseline hazard function $h_0(\tau)$, and consequently, the hazard function $h(\tau|\mathbf{x})$ of the Cox model can be calculated with the predicted survival risk. The survival function can then be estimated, and finally, the survival time is calculated using two approaches: (1) taking the expectation of the survival function, termed as "Expect", and (2) thresholding the survival function, termed as "Thresh." For CDOR, because the original method required manually selected patches of size 1024 × 1024, which are not publicly available, we trained another tumor classifier using the Camelyon17 dataset 40 with the patch size 1024 × 1024 in the same manner as the patch classifier used to construct a graph structure in our experiments. Then, for each WSI, we selected the patch with the highest tumor probability and obtained the performance from the selected patch. The survival time prediction results using the proposed model and other methods are presented in Table 2.
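To make the risk-to-time conversion concrete, the following sketch re-implements the idea in plain Python instead of calling lifelines. It substitutes a standard Breslow-type estimate of the cumulative baseline hazard, so it is not a reproduction of what the package does internally. All function names are hypothetical, and the threshold level of 0.5 is an assumption, since the text does not state the value used for "Thresh".

```python
import math


def breslow_cumulative_hazard(times, events, risks):
    """Breslow-type estimate of the cumulative baseline hazard H0.

    Returns a list of (event_time, H0(event_time)) in increasing time order.
    """
    event_times = sorted({t for t, e in zip(times, events) if e == 1})
    cumulative, grid = 0.0, []
    for t in event_times:
        deaths = sum(1 for ti, e in zip(times, events) if e == 1 and ti == t)
        at_risk = sum(math.exp(r) for ti, r in zip(times, risks) if ti >= t)
        cumulative += deaths / at_risk
        grid.append((t, cumulative))
    return grid


def survival_curve(grid, risk):
    """S(t | x) = exp(-H0(t) * exp(R)) at each event time."""
    return [(t, math.exp(-h0 * math.exp(risk))) for t, h0 in grid]


def expected_time(curve):
    """'Expect': area under the step survival curve up to the last event time."""
    area, prev_t, prev_s = 0.0, 0.0, 1.0
    for t, s in curve:
        area += prev_s * (t - prev_t)
        prev_t, prev_s = t, s
    return area


def threshold_time(curve, level=0.5):
    """'Thresh': first event time where survival drops to `level` or below."""
    for t, s in curve:
        if s <= level:
            return t
    return curve[-1][0]  # curve never crosses: fall back to the last event time


if __name__ == "__main__":
    grid = breslow_cumulative_hazard([1, 2, 3], [1, 1, 1], [0.0, 0.0, 0.0])
    for risk in (0.0, 1.0):
        curve = survival_curve(grid, risk)
        print(risk, expected_time(curve), threshold_time(curve))
```

A higher predicted risk shifts the survival curve downward, so both converted times shrink. The conversion depends on the estimated baseline hazard, which is the extra step that direct time prediction avoids.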
As can be observed, the MAEs of the proposed model were obtained as 123.2, 167.5, and 303.3 days for BLCA, BRCA, and GBM, respectively, outperforming the existing approaches on BLCA and BRCA. Especially for BLCA, the MAE was significantly lower (-30.3 days compared to the second-lowest one obtained by DeepGraphSurv 19 ) than the other approaches, while the C-index of the time-prediction results was also higher (+0.029 compared to the second-highest one obtained by DeepGraphSurv 19 ) than the other approaches. Figure 1 shows the training and validation loss per epoch on BLCA, demonstrating that no overfitting problem arose. For BRCA, one can see that even though the C-index obtained from DeepAttnMISL was the highest among all approaches, its survival time prediction accuracy is lower than the one obtained from the proposed model. This is because more complicated procedures, such as the estimation of the baseline hazard function, are required to convert the survival risk results into survival time results. As for GBM, the result was slightly worse than DeepAttnMISL and PatchGCN; we considered this might be due to the significant difference between the number of patients and WSIs in GBM (57 and 150). The two risk-prediction approaches both adopted early fusion strategies to aggregate the information of multiple WSIs from the same patient, and consequently, the training and prediction were both performed directly on the patient-level. However, the proposed model was trained on the WSI-level, and the final patient-level results were obtained by averaging the WSI-level results of one patient.
Ablation study. We conducted an ablation study to verify the effectiveness of the risk feature $\mathbf{h}_r$. To this end, we directly applied ordinal regression to the time feature $\mathbf{h}_t$ extracted by the TP-GCN and calculated the MAE and C-index of the obtained results. We also reported the C-index of the survival risk prediction results. As shown in Table 3, the C-index obtained from the results using the time feature alone for BLCA, BRCA, and GBM was 0.808, 0.565, and 0.542, respectively, indicating that the survival ranking between patients was not sufficiently handled. When the risk feature was involved, further ranking information could be provided to guide the time prediction, leading to a higher C-index. Meanwhile, the MAEs for the TP-GCN were obtained as 182.4, 178.7, and 336.6 days for the three datasets, respectively. However, for the proposed model, because the ranking was also considered during time-prediction training, the MAEs of the TP-GCN were reduced by 59.2, 11.2, and 33.3 days for the three datasets, respectively.

Figure 2 visualizes the results on the BLCA test set. Ideally, the closer the point to the dotted diagonal line, the better the performance. Because any points above the line are all correct predictions for the censored data, we excluded the censored data for this figure to better understand the results. When only the time feature $\mathbf{h}_t$ was used, the results were close to the ground truth overall. However, as shown in Fig. 2a, the time prediction results inside the cyan circle are very close to the ones inside the blue circle. In other words, although the prediction results are reasonably close to the ground truth, their death rankings are inaccurate. Notice the color of these points, i.e., the predicted risk. The predicted risks of the data inside the blue circle are lower than the ones in the cyan circle. When the risk feature $\mathbf{h}_r$ was combined for time prediction, it guided the model to calibrate time predictions such that the death rankings among patients could be better preserved. As a result, the time predictions inside the blue circle were increased, whereas the ones inside the cyan circle were decreased, as shown in Fig. 2b. Note that the final time prediction results are closer to the ground truth with more correct death rankings.

In addition, we conducted the experiment by averaging the risk feature $\mathbf{h}_r$ and time feature $\mathbf{h}_t$ rather than concatenating them. As shown in Table 4, the averaging resulted in a slight improvement on BRCA but degradation on BLCA and GBM compared with concatenation; thus, we adopted concatenation as the feature fusion strategy in our experiments.

Table 2. Results of the proposed and existing methods. Because DeepAttnMISL 17, PatchGCN 28, and DeepGraphSurv 19 are originally survival risk prediction methods, we used the Python package lifelines 39 to convert the survival risk into survival time for evaluation. The evaluation was conducted on both time accuracy (MAE in days) and ranking accuracy (C-index). *Indicates the survival time results calculated by lifelines. Columns, for each of BLCA, BRCA, and GBM: MAE (Expect, Thresh) and C-index (Risk, Time).
Instead of pre-training the RP-GCN and using it as a frozen risk feature extractor, we also conducted an ablation study by jointly training the RP-GCN and TP-GCN. As can be seen from Table 4, jointly training these two networks leads to higher MAEs and lower C-index scores. We emphasize that RP-GCN is intended to provide more accurate sequential guidance for time prediction; thus, training the two networks simultaneously causes RP-GCN to be optimized not only by the Cox loss but also by a set of cross-entropy losses, which are unrelated to ranking prediction.
We further visualized the position of nodes that remained after the last SagPool 41 operation. Since the SagPool 41 operation is based on a self-attention mechanism, the remaining nodes can be naturally considered as a high-attention area of RP-GCN and TP-GCN. As can be observed from Fig. 4, RP-GCN and TP-GCN showed distinct differences in areas of high attention. In other words, the time feature and risk feature convey complementary information for more accurate predictions.

Conclusion and discussion
In this paper, we presented a novel approach for accurate survival time prediction that addresses this problem from two perspectives. First, the information available in a WSI needs to be aggregated globally without any intervention. Second, in addition to estimating the survival times of patients based on the characteristics of the patients themselves, the characteristics and survival times of other patients can also be used as comparison objects to estimate the survival times of target patients comprehensively. Considering these two perspectives, we proposed a two-branch GCN-based model that exploits the Cox model to capture the ranking relationships among patients in the form of the risk feature. The risk feature was then concatenated with the time feature to guide the model in making more precise time predictions with correct rankings. Experimental results on three public datasets demonstrated that the proposed risk-aware prediction method yielded more accurate survival time prediction results compared to its risk-agnostic version.
Several directions for future studies should be considered. First, we pre-trained the RP-GCN and fixed it during the training of the TP-GCN because our primary objective was a precise survival time prediction. For certain applications in which survival risk is more important, we expect that the proposed model can be trained in the opposite manner by using the time feature as a guide for risk prediction. Second, we used a pre-trained patch classifier to extract tumor patches for the graph construction. End-to-end training from the WSIs to survival time predictions is a challenging but promising direction for our future work. Last, our framework considered each WSI as an individual during training, and the final patient-level prediction was performed by taking the average of WSI-level predictions. We plan to apply an early fusion of WSI-level information to enable more precise patient-level predictions.

Methods
Problem formulation and motivation. In general, an instance of survival data can be represented as $\{\mathbf{x}, \tau, \delta\}$, where $\mathbf{x}$ is the feature vector of the patient and $\delta$ is the censoring status indicator, i.e., 1 for uncensored (death observed) data and 0 for censored data. $\tau$ is the last observation time for censored data or the survival time for uncensored data. The goal of survival time prediction is to predict $\tau$ by using $\mathbf{x}$ and $\delta$.
One of the most widely used methods for survival prediction is the construction of the Cox proportional hazards model 21. The hazard function, which assesses the relationship between the distribution of failure time (death time for survival prediction) and $\mathbf{x}$ 21, is defined as $h(\tau|\mathbf{x}) = h_0(\tau)\exp(\boldsymbol{\beta}^T\mathbf{x})$, where $h_0(\tau)$ is the baseline hazard function and $\boldsymbol{\beta}$ is a regression parameter vector. Katzman et al. 5 first proposed a deep-learning-based method to directly predict the survival risk $R$ ($= \boldsymbol{\beta}^T\mathbf{x}$) using a fully connected network, demonstrating its superiority over the conventional regression-based methods. The survival risk R can be estimated by minimizing the negative log partial likelihood l(R), which is defined as follows:

$$l(R) = -\sum_{i=1}^{M}\delta_i\Big(R_i - \log\sum_{j:\,\tau_j \ge \tau_i}\exp(R_j)\Big),$$

where M is the number of samples, and the subscript i represents the data index. l(R) has been frequently used as the loss function for training survival risk prediction models 14,16-19,27-29.
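A minimal plain-Python sketch of the negative log partial likelihood may help fix ideas. The function name is hypothetical, and the loss is summed over uncensored samples without normalization, which is an assumption; implementations often divide by the number of observed events.

```python
import math


def cox_neg_log_partial_likelihood(risks, times, events):
    """Negative log partial likelihood of the Cox model.

    For each uncensored sample i, the risk set contains every sample j
    whose observation time is at least times[i].
    """
    loss = 0.0
    for r_i, t_i, e_i in zip(risks, times, events):
        if e_i != 1:
            continue  # censored samples only appear inside risk sets
        risk_set = sum(math.exp(r_j) for r_j, t_j in zip(risks, times) if t_j >= t_i)
        loss -= r_i - math.log(risk_set)
    return loss


if __name__ == "__main__":
    times, events = [1, 2], [1, 1]
    # The patient who dies first should receive the higher risk.
    print(cox_neg_log_partial_likelihood([1.0, 0.0], times, events))
    print(cox_neg_log_partial_likelihood([0.0, 1.0], times, events))
```

The loss depends only on the relative ordering of risks within each risk set, which is why the resulting feature carries ranking information rather than absolute times.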
The Cox proportional hazards model includes the baseline hazard function, which is very challenging to estimate 15. Therefore, an alternative method is to estimate the survival time $\tau$ directly from the feature vector $\mathbf{x}$ 15. However, it does not consider that the risk orders among patients convey vital information 44. Therefore, we propose to use the risk prediction features as guidance for the survival time prediction model.

Proposal overview. The overall workflow of the proposed method is shown in Fig. 3. For a given WSI, we first apply a graph construction to generate the graph representation of the WSI, which is described in the "Graph construction" section. Two GCNs are then utilized, namely the risk-prediction GCN (RP-GCN) and time-prediction GCN (TP-GCN), where we adopt the network architecture from 36. While the former is pre-trained to predict the survival risk using the Cox model and frozen, the latter is trained by concatenating features from the former to guide survival time prediction. Specifically, given the concatenated risk-time feature vector, ordinal regression is applied to predict the survival time.

Graph construction. Before constructing the graph, the background regions are removed using the method proposed in 42. The remaining foreground tissue regions are cropped into non-overlapping patches of size 512 × 512 (for the 20× magnification scale). To avoid the need for manual intervention by medical experts, a patch-level binary (normal/tumor) classifier is used to extract tumor-like patches from a WSI automatically. We adopt ResNet-18 33 as the backbone for our classifier and train it using the Camelyon17 dataset 40, which contains WSIs with pixel-level ROI annotations. The result is a set of tumor patches $\boldsymbol{\rho} = \{\rho \mid f(\rho) > 0.5\}$, where f is the tumor classifier that outputs the tumor probability of an input patch $\rho$.
Then, the features of each patch in $\boldsymbol{\rho}$ are extracted by an ImageNet 43 pre-trained ResNet-18 33 and treated as vertices in a graph. Edges between vertices are established by the k-nearest neighbors algorithm using the Euclidean distance between patch center positions with a spatial distance threshold d, because patches closer to each other are more likely to interact 11. The constructed graph can be represented as G = (V, E), where V is the set of vertices and E is the set of edges in the graph, defined as

$$e_{ij} = \begin{cases} 1, & \text{if } v_j \in \mathrm{knn}(v_i) \text{ and } D(v_i, v_j) < d,\\ 0, & \text{otherwise,} \end{cases}$$

where $e_{ij} \in E$ indicates the edge between two vertices $v_i$ and $v_j$, $D(\cdot,\cdot)$ is the Euclidean distance between the center positions of two patches, and $\mathrm{knn}(v_i)$ indicates the k-nearest neighbors of $v_i$. The graph construction procedure is illustrated in Fig. 5.
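The edge rule can be sketched as follows, using the reported settings k = 8 and d = 2560 pixels as defaults. The function name is hypothetical, and treating edges as undirected and using a strict inequality for the distance threshold are assumptions of this sketch.

```python
import math


def build_knn_graph(centers, k=8, d=2560.0):
    """Connect each patch to its k nearest patches that lie within distance d.

    centers: list of (x, y) patch-center coordinates in pixels.
    Returns a set of undirected edges (i, j) with i < j (assumption).
    """
    edges = set()
    for i, (xi, yi) in enumerate(centers):
        neighbours = sorted(
            (math.hypot(xi - xj, yi - yj), j)
            for j, (xj, yj) in enumerate(centers)
            if j != i
        )
        for dist, j in neighbours[:k]:
            if dist < d:
                edges.add((min(i, j), max(i, j)))
    return edges


if __name__ == "__main__":
    # Three adjacent 512-pixel patches and one isolated patch far away.
    centers = [(0, 0), (512, 0), (1024, 0), (5000, 0)]
    print(sorted(build_knn_graph(centers, k=2)))
```

The distance threshold keeps an isolated tumor patch from being linked to remote tissue merely because it needs k neighbors.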
Risk-prediction GCN. The generated graph G is first used for the RP-GCN shown in Fig. 6, which contains three GCN blocks and a single GCN layer, followed by a global max pooling, a global average pooling, and two fully connected layers. A GCN block can be expressed as follows:

$$H^{(l+1)} = \mathrm{SagPool}\big(\mathcal{G}(H^{(l)}, A),\, r\big), \qquad \mathcal{G}(H^{(l)}, A) = \sigma\big(\tilde{D}^{-1/2}\tilde{A}\tilde{D}^{-1/2} H^{(l)} W^{(l)}\big),$$

where $\mathcal{G}$ represents the graph convolution operation 36 for one GCN layer, in which A is the adjacency matrix, and $\tilde{A} = A + I$ is the adjacency matrix including self-loops. $\tilde{D}$ is the degree matrix, whose diagonal entry is given as $\tilde{D}_{ii} = \sum_j \tilde{A}_{ij}$, $H^{(l)}$ is the feature matrix at the l-th layer with $H^{(0)} = V$, and $W^{(l)}$ is the trainable parameters for layer l. $\sigma$ is an activation function, and SagPool represents the self-attention graph pooling 41 with a pooling ratio r.

Time-prediction GCN. Following the ordinal regression framework 15, we divide the range of survival time into N intervals, i.e., $\{T_0, T_1, ..., T_N\}$ with $T_0 < T_1 < ... < T_N$. The intervals are empirically chosen to evenly distribute patients' survival time in the training database. We then define a binary label for each patient as follows:

$$b_n = \begin{cases} 1, & \text{if } \tau > T_n,\\ 0, & \text{otherwise,} \end{cases}$$

where $\tau$ is the survival time (or the last observation time for the censored data), and n is the index of the time interval. In this manner, the survival time prediction is transformed into an ordinal regression problem that solves N binary classification tasks, where the n-th classifier predicts whether $\tau$ is greater than $T_n$. We adopt the censoring-aware loss 15 so that both uncensored and censored data can be involved during training. The loss function is given as:

$$\mathcal{L}(\mathbf{o}, \mathbf{B}, \boldsymbol{\delta}) = -\sum_{n=1}^{N}\Big[\delta_n\big(b_n\log o_n + (1-b_n)\log(1-o_n)\big) + (1-\delta_n)\,b_n\log o_n\Big],$$

where $\mathbf{B} = \{b_1, b_2, \cdots, b_N\}$ and $\boldsymbol{\delta} = \{\delta_1, \delta_2, \cdots, \delta_N\}$. The output of the n-th classifier is denoted as $o_n$, and $\mathbf{o} = \{o_1, o_2, \cdots, o_N\}$. This means that for uncensored training samples, the cross-entropy is calculated for all $b_n$. For censored training samples, however, the last observation time is still valuable for training. Therefore, we ignore the intervals where $b_n = 0$, and the loss is measured only when $b_n = 1$.

In the testing stage, the final predicted survival time $\hat{\tau}$ is calculated as:

$$\hat{\tau} = \sum_{n=1}^{N} I_{o_n}\,(T_n - T_{n-1}),$$

where $I_{o_n} = 1$ if $o_n \ge 0.5$, and $I_{o_n} = 0$ otherwise.
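The ordinal-regression steps described above (binary labels, censoring-aware loss, and decoding) can be sketched for a single sample as follows. The function names are hypothetical. The indexing of labels against boundaries $T_1, ..., T_N$ and the decoding rule, which adds the width of every interval whose classifier output is at least 0.5, are assumptions made for illustration.

```python
import math


def ordinal_labels(time, boundaries):
    """b_n = 1 if time > T_n for n = 1..N, with boundaries = [T_0, ..., T_N]."""
    return [1 if time > t_n else 0 for t_n in boundaries[1:]]


def censoring_aware_loss(outputs, labels, event, eps=1e-7):
    """Cross-entropy over all classifiers for uncensored samples (event == 1);
    for censored samples only the classifiers with label 1 contribute."""
    loss = 0.0
    for o, b in zip(outputs, labels):
        o = min(max(o, eps), 1.0 - eps)  # clamp for numerical stability
        if event == 1:
            loss -= b * math.log(o) + (1 - b) * math.log(1.0 - o)
        elif b == 1:
            loss -= math.log(o)
    return loss


def decode_time(outputs, boundaries):
    """Sum the widths of the intervals whose classifier output is >= 0.5."""
    return sum(
        boundaries[n] - boundaries[n - 1]
        for n, o in enumerate(outputs, start=1)
        if o >= 0.5
    )


if __name__ == "__main__":
    boundaries = [0, 100, 200, 300]           # T_0 .. T_3, in days
    labels = ordinal_labels(150, boundaries)  # patient observed for 150 days
    outputs = [0.9, 0.2, 0.1]                 # toy classifier outputs
    print(labels, decode_time(outputs, boundaries))
    print(censoring_aware_loss(outputs, labels, event=1))
    print(censoring_aware_loss(outputs, labels, event=0))
```

For a censored patient, the classifiers for intervals beyond the last observation time receive no gradient, so the model is never penalized for predicting that the patient lived longer than observed.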

Data availability
The datasets generated during and/or analysed during the current study are available from the corresponding author on reasonable request.

Code availability
The code is available to editors and reviewers upon request. The code will be publicly available at https://github.com/ZXXu96/RAST upon acceptance.

Figure 6. Illustration of RP-GCN. It contains four graph convolution layers and three self-attention graph pooling 41 layers, followed by global max and mean pooling layers. Two fully connected layers are further applied for the survival risk estimation using the Cox model 21. Note that these two fully connected layers are detached after the RP-GCN training is finished.